{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6bb51cd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "sys.path.append(os.path.abspath('../../code'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcee14a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from pathlib import Path\n",
    "\n",
    "import regex\n",
    "from collections import Counter\n",
    "\n",
    "import numpy as np\n",
    "np.set_printoptions(formatter={'float': lambda x: \"{0:0.3f}\".format(x)})\n",
    "import pandas as pd\n",
    "pd.set_option('display.max_columns', 15)\n",
    "pd.set_option('display.width', 320)\n",
    "\n",
    "from utils.io import read_label_config\n",
    "from utils.corpus import JsonlinesAnnotationsCorpus\n",
    "from utils.bsc_model import BSCModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "441cec08",
   "metadata": {},
   "outputs": [],
   "source": [
    "from types import SimpleNamespace\n",
    "\n",
    "args = SimpleNamespace()\n",
    "args.label_config_file = '../../data/annotation/doccano_label_config.json'\n",
    "args.input_file = '../../data/annotation/parsed/uk-commons_annotations.jsonl'\n",
    "\n",
    "args.data_path = '../../data/annotation/annotations/'\n",
    "args.data_folder_pattern = 'uk-commons'\n",
    "args.data_file_format = 'csv'\n",
    "\n",
    "args.output_file = '../../data/annotation/labeled/uk-commons_all_labeled.jsonl'\n",
    "args.overwrite_output = False\n",
    "\n",
    "args.verbose = True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9940ee7",
   "metadata": {},
   "source": [
    "## Read the annotations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0485a7f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "job\n",
      "group-mentions-annotation-uk-commons-round-01    1575\n",
      "Name: count, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "fps = [str(fp) for fp in Path(args.data_path).glob(f'*{args.data_folder_pattern}*/*.{args.data_file_format}')]\n",
    "# read metadata\n",
    "metadata = pd.concat({fp.split('/')[-1].replace('_ids.csv', ''): pd.read_csv(fp) for fp in fps}, axis=0)\n",
    "metadata['job'] = metadata.index.get_level_values(0)\n",
    "metadata.reset_index(drop=True, inplace=True)\n",
    "print(metadata.job.value_counts())\n",
    "jobs = metadata.job.unique().tolist()\n",
    "jobs.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c9b210fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1573/1573 [00:02<00:00, 653.69it/s]\n"
     ]
    }
   ],
   "source": [
    "cat2code = read_label_config(args.label_config_file)\n",
    "acorp = JsonlinesAnnotationsCorpus(cat2code)\n",
    "\n",
    "acorp.load_from_jsonlines(args.input_file, verbose=args.verbose)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "56fd658e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# docs =  1573\n",
      "# gold items = 121\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'O': 0,\n",
       " 'I-social group': 1,\n",
       " 'B-social group': 7,\n",
       " 'I-political group': 2,\n",
       " 'B-political group': 8,\n",
       " 'I-political institution': 3,\n",
       " 'B-political institution': 9,\n",
       " 'I-organization, public institution, or collective actor': 4,\n",
       " 'B-organization, public institution, or collective actor': 10,\n",
       " 'I-implicit social group reference': 5,\n",
       " 'B-implicit social group reference': 11,\n",
       " 'I-unsure': 6,\n",
       " 'B-unsure': 12}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n_docs = acorp.ndocs\n",
    "n_gold_labeled = len([1 for doc in acorp.docs if doc.n_labels > 0])\n",
    "print('# docs = ', n_docs)\n",
    "print('# gold items =', n_gold_labeled)\n",
    "acorp.label_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "60a595f7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "b3d0d7b0926cb70b858ced44aab87880\n",
      "2044106345e3988ce759190853dc980c\n"
     ]
    }
   ],
   "source": [
    "# ensure that 'unsure' never in GOLD annotations\n",
    "for doc in acorp.docs:\n",
    "    if doc.n_labels > 0 and cat2code['B-unsure'] in doc.labels['GOLD']:\n",
    "        print(doc.id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "4d6237ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(acorp.ndocs):\n",
    "    if acorp.docs[i].n_labels > 0:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "25ae0a6e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'GOLD': array([ 0,  0,  0,  0,  0,  0,  0, 10,  4,  4,  4,  0,  0,  0,  0,  0,  0,\n",
       "         0,  0,  0,  0,  0,  0,  7,  0,  7,  0])}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "acorp.docs[i].labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "278841e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# drop 'unsure' labels\n",
    "def recode_labels(codes):\n",
    "    out = codes.copy()\n",
    "    # set \"unsure\" to outside\n",
    "    out[codes ==  6] = 0\n",
    "    out[codes == 12] = 0\n",
    "    # reset all others\n",
    "    out[out>6] -= 1\n",
    "    return(out)\n",
    "\n",
    "for i in range(acorp.ndocs):\n",
    "    for j in range(acorp.docs[i].n_annotations):\n",
    "        acorp.docs[i].annotations[ acorp.docs[i].annotators[j] ] = recode_labels(acorp.docs[i].annotations[ acorp.docs[i].annotators[j] ])\n",
    "        for k in acorp.docs[i].labels.keys():\n",
    "            acorp.docs[i].labels[k] = recode_labels(acorp.docs[i].labels[k])\n",
    "    if acorp.docs[i].n_labels > 0:\n",
    "        acorp.docs[i].labels['GOLD'] = recode_labels(acorp.docs[i].labels['GOLD'])\n",
    "\n",
    "\n",
    "idx = acorp.label_map.pop('B-unsure')\n",
    "acorp.label_map_inv.pop(idx)\n",
    "idx = acorp.label_map.pop('I-unsure')\n",
    "acorp.label_map_inv.pop(idx)\n",
    "\n",
    "# [print(c, 't', l) for c, l in acorp.label_map_inv.items()]\n",
    "for c in range(7,12): acorp.label_map[acorp.label_map_inv[c]] -= 1\n",
    "# [print(c, 't', l) for l, c in acorp.label_map.items()]\n",
    "\n",
    "acorp.label_map_inv = {c: l for l, c in acorp.label_map.items()}\n",
    "# [print(c, 't', l) for c, l in acorp.label_map_inv.items()]\n",
    "\n",
    "len(acorp.label_map) == len(acorp.label_map_inv)\n",
    "\n",
    "acorp.inside_labels = list(range(1, 6))\n",
    "acorp.beginning_labels = list(range(6, 11))\n",
    "\n",
    "acorp.ndocs = len(acorp.docs)\n",
    "acorp.doc_ids = [doc.id for doc in acorp.docs]\n",
    "acorp.doc_id2idx = {doc.id: i for i, doc in enumerate(acorp.docs)}\n",
    "acorp.doc_idx2id = {i: doc.id for i, doc in enumerate(acorp.docs)}\n",
    "\n",
    "acorp.annotator_label_counts = acorp._count_annotator_labels()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "97d23e73",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Counter({0: 2166, 1: 250, 3: 136, 4: 95, 6: 68, 2: 41, 7: 36, 5: 29, 8: 12})"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# count label types in GOlD\n",
    "# ensure that 'unsure' never in GOLD annotations\n",
    "gold_types = Counter()\n",
    "for doc in acorp.docs:\n",
    "    if 'GOLD' in doc.labels:\n",
    "        gold_types.update(doc.labels['GOLD'].tolist())\n",
    "gold_types"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "8b9c3ace",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>emarie</th>\n",
       "      <th>sjasmin</th>\n",
       "      <th>gold</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>21798</td>\n",
       "      <td>22242</td>\n",
       "      <td>2166.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>802</td>\n",
       "      <td>682</td>\n",
       "      <td>250.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>291</td>\n",
       "      <td>326</td>\n",
       "      <td>68.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>193</td>\n",
       "      <td>97</td>\n",
       "      <td>41.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>97</td>\n",
       "      <td>62</td>\n",
       "      <td>36.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>599</td>\n",
       "      <td>561</td>\n",
       "      <td>136.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>394</td>\n",
       "      <td>394</td>\n",
       "      <td>12.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>440</td>\n",
       "      <td>301</td>\n",
       "      <td>95.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>213</td>\n",
       "      <td>207</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>94</td>\n",
       "      <td>49</td>\n",
       "      <td>29.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>73</td>\n",
       "      <td>69</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    emarie  sjasmin    gold\n",
       "0    21798    22242  2166.0\n",
       "1      802      682   250.0\n",
       "6      291      326    68.0\n",
       "2      193       97    41.0\n",
       "7       97       62    36.0\n",
       "3      599      561   136.0\n",
       "8      394      394    12.0\n",
       "4      440      301    95.0\n",
       "9      213      207     NaN\n",
       "5       94       49    29.0\n",
       "10      73       69     NaN"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "annotations = acorp.annotator_label_counts\n",
    "gold = Counter()\n",
    "for labs in [doc.labels['GOLD'].tolist() for doc in acorp.docs if 'GOLD' in doc.labels]:\n",
    "    gold.update(labs)\n",
    "pd.DataFrame(annotations).join(pd.DataFrame({'gold': dict(gold)})).loc[acorp.label_map.values()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "b31fb2de",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9757688723205965\n"
     ]
    }
   ],
   "source": [
    "# define custom token cleaning function\n",
    "def clean_tokens(x):\n",
    "\n",
    "    # insert special characters\n",
    "    if regex.match(r\"^\\p{Sc}\", x):\n",
    "        return(\"<MONEY>\")\n",
    "\n",
    "    if regex.match(r\"^\\d+([.,/-]\\d+)*\\w*$\", x):\n",
    "        return(\"<DIGITS>\")\n",
    "\n",
    "    return(x)\n",
    "\n",
    "# clean docs' tokens (in place) and collect results in Counter object\n",
    "cleaned_vocab = Counter()\n",
    "for i in range(acorp.ndocs):\n",
    "    cleaned_vocab.update(acorp.docs[i].clean_tokens(fun=clean_tokens))\n",
    "\n",
    "# reduces vocab size somewhat!\n",
    "print(len(cleaned_vocab)/len(acorp.vocab))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8acfcb88",
   "metadata": {},
   "source": [
    "## Annotation aggregation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b86013f",
   "metadata": {},
   "source": [
    "### Prepare the Baysian sequene combination (BSC) sequence annotation model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "35f71c33",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parallel can run 10 jobs simultaneously, with 10 cores\n"
     ]
    }
   ],
   "source": [
    "amodel = BSCModel(acorp, max_iter=30, gold_labels='GOLD', verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "4616fe2f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# label classes: 11\n",
      "# annotators:    2\n",
      "# docs: 1573\n",
      "# tokens: 37500\n",
      "# docs with gold labels: 121\n",
      "# tokens with gold labels: 667\n"
     ]
    }
   ],
   "source": [
    "print('# label classes:', amodel.num_classes)\n",
    "print('# annotators:   ', amodel.num_annotators)\n",
    "print('# docs:', amodel.num_docs)\n",
    "print('# tokens:', amodel.num_tokens)\n",
    "print('# docs with gold labels:', (amodel.gold[amodel.doc_start == 1] > -1).sum())\n",
    "print('# tokens with gold labels:', (amodel.gold > 0).sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "e1b3fab4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.0"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# inspect the alpha prior\n",
    "alpha0 = amodel.model.A.alpha0.copy()\n",
    "# note: default prior belief is that annotators assign correct label 2 out of 3 times\n",
    "alpha0[0,0]/alpha0[0,1:].sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "3a568c13",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                                    =>O  =>I-soc  =>I-pol  =>I-pol  =>I-org  =>I-imp  =>B-soc  =>B-pol  =>B-pol  =>B-org  =>B-imp\n",
      "O                                                   6.0      0.0      0.0      0.0      0.0      0.0      1.0      1.0      1.0      1.0      1.0\n",
      "I-social group                                      1.0      1.0      0.0      0.0      0.0      0.0      0.0      1.0      1.0      1.0      1.0\n",
      "I-political group                                   1.0      0.0      1.0      0.0      0.0      0.0      1.0      0.0      1.0      1.0      1.0\n",
      "I-political institution                             1.0      0.0      0.0      1.0      0.0      0.0      1.0      1.0      0.0      1.0      1.0\n",
      "I-organization, public institution, or collecti...  1.0      0.0      0.0      0.0      1.0      0.0      1.0      1.0      1.0      0.0      1.0\n",
      "I-implicit social group reference                   1.0      0.0      0.0      0.0      0.0      1.0      1.0      1.0      1.0      1.0      0.0\n",
      "B-social group                                      1.0      5.0      0.0      0.0      0.0      0.0      0.0      1.0      1.0      1.0      1.0\n",
      "B-political group                                   1.0      0.0      5.0      0.0      0.0      0.0      1.0      0.0      1.0      1.0      1.0\n",
      "B-political institution                             1.0      0.0      0.0      5.0      0.0      0.0      1.0      1.0      0.0      1.0      1.0\n",
      "B-organization, public institution, or collecti...  1.0      0.0      0.0      0.0      5.0      0.0      1.0      1.0      1.0      0.0      1.0\n",
      "B-implicit social group reference                   1.0      0.0      0.0      0.0      0.0      5.0      1.0      1.0      1.0      1.0      0.0\n"
     ]
    }
   ],
   "source": [
    "# inspect transitions prior\n",
    "new_beta0 = amodel.model.LM.beta0.copy()\n",
    "cats = [k for k, v in sorted(acorp.label_map.items(), key=lambda item: item[1])]\n",
    "tmp = pd.DataFrame(new_beta0.round(2), index = cats, columns = ['=>'+c[:5]  for c in cats])\n",
    "# note: read this by rows and columns indicate label categories 0...12 and the cell (i, j) indicates the prior from\n",
    "#       probability that the label in column j follows the label in row i\n",
    "# for examples:\n",
    "# after the \"outside\" label, only itself and B-* labels are allowed\n",
    "tmp.iloc[[0]]\n",
    "# after the \"I-social group\" label, only itself, O, or the other B-* labels are allowed\n",
    "tmp.iloc[[1]]\n",
    "new_beta0[1, 6] = 1e-12\n",
    "# after the \"B-social group\" label, O, \"I-social group\", or any of the other B-* labels are allowed\n",
    "tmp.iloc[[6]]\n",
    "new_beta0[6, 6] = 1e-12\n",
    "\n",
    "# apply logic to each type\n",
    "for c in range(2, 6):\n",
    "    new_beta0[c, c+5] = 1e-12\n",
    "    new_beta0[c+5, c+5] = 1e-12\n",
    "\n",
    "# verify\n",
    "print(pd.DataFrame(new_beta0.round(2), index = cats, columns = ['=>'+c[:5]  for c in cats]))\n",
    "\n",
    "# reset\n",
    "amodel.reset_label_transitions_prior(new_beta0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5cc339a8",
   "metadata": {},
   "source": [
    "### Fit the BSC model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "ba4f1ebe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BSC: run() called with annotation matrix with shape = (37500, 2)\n",
      "BC iteration 0 in progress\n",
      "BC iteration 0: computed label probabilities\n",
      "BC iteration 0: updated label model\n",
      "BC iteration 0: updated worker models\n",
      "BC iteration 1 in progress\n",
      "BAC iteration 1: completed forward pass\n",
      "BAC iteration 1: completed backward pass\n",
      "BC iteration 1: computed label probabilities\n",
      "BC iteration 1: updated label model\n",
      "BC iteration 1: updated worker models\n",
      "Computing LB=-759947.2036: label model and labels=-76513.0000, annotator model=-457.2500, features=-682976.9301\n",
      "BSC: max. difference at iteration 2: inf\n",
      "BC iteration 2 in progress\n",
      "BAC iteration 2: completed forward pass\n",
      "BAC iteration 2: completed backward pass\n",
      "BC iteration 2: computed label probabilities\n",
      "BC iteration 2: updated label model\n",
      "BC iteration 2: updated worker models\n",
      "Computing LB=-741708.9618: label model and labels=-59333.7656, annotator model=-329.5000, features=-682045.7118\n",
      "BSC: max. difference at iteration 3: 18238.24182\n",
      "BC iteration 3 in progress\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/hlicht/miniforge3/envs/group_mention_detection/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py:752: UserWarning: A worker stopped while some jobs were given to the executor. This can be caused by a too short worker timeout or by a memory leak.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BAC iteration 3: completed forward pass\n",
      "BAC iteration 3: completed backward pass\n",
      "BC iteration 3: computed label probabilities\n",
      "BC iteration 3: updated label model\n",
      "BC iteration 3: updated worker models\n",
      "Computing LB=-739811.8993: label model and labels=-57484.5938, annotator model=-303.3750, features=-682023.9696\n",
      "BSC: max. difference at iteration 4: 1897.06244\n",
      "BC iteration 4 in progress\n",
      "BAC iteration 4: completed forward pass\n",
      "BAC iteration 4: completed backward pass\n",
      "BC iteration 4: computed label probabilities\n",
      "BC iteration 4: updated label model\n",
      "BC iteration 4: updated worker models\n",
      "Computing LB=-739492.9847: label model and labels=-57199.0859, annotator model=-296.8750, features=-681996.9613\n",
      "BSC: max. difference at iteration 5: 318.91457\n",
      "BC iteration 5 in progress\n",
      "BAC iteration 5: completed forward pass\n",
      "BAC iteration 5: completed backward pass\n",
      "BC iteration 5: computed label probabilities\n",
      "BC iteration 5: updated label model\n",
      "BC iteration 5: updated worker models\n",
      "Computing LB=-739332.2476: label model and labels=-57080.4453, annotator model=-292.8750, features=-681958.8960\n",
      "BSC: max. difference at iteration 6: 160.73718\n",
      "BC iteration 6 in progress\n",
      "BAC iteration 6: completed forward pass\n",
      "BAC iteration 6: completed backward pass\n",
      "BC iteration 6: computed label probabilities\n",
      "BC iteration 6: updated label model\n",
      "BC iteration 6: updated worker models\n",
      "Computing LB=-739207.5568: label model and labels=-56987.2500, annotator model=-290.2500, features=-681930.0881\n",
      "BSC: max. difference at iteration 7: 124.69077\n",
      "BC iteration 7 in progress\n",
      "BAC iteration 7: completed forward pass\n",
      "BAC iteration 7: completed backward pass\n",
      "BC iteration 7: computed label probabilities\n",
      "BC iteration 7: updated label model\n",
      "BC iteration 7: updated worker models\n",
      "Computing LB=-739114.0903: label model and labels=-56910.7500, annotator model=-287.3750, features=-681916.0121\n",
      "BSC: max. difference at iteration 8: 93.46655\n",
      "BC iteration 8 in progress\n",
      "BAC iteration 8: completed forward pass\n",
      "BAC iteration 8: completed backward pass\n",
      "BC iteration 8: computed label probabilities\n",
      "BC iteration 8: updated label model\n",
      "BC iteration 8: updated worker models\n",
      "Computing LB=-739063.8024: label model and labels=-56866.7578, annotator model=-285.3750, features=-681911.7009\n",
      "BSC: max. difference at iteration 9: 50.28783\n",
      "BC iteration 9 in progress\n",
      "BAC iteration 9: completed forward pass\n",
      "BAC iteration 9: completed backward pass\n",
      "BC iteration 9: computed label probabilities\n",
      "BC iteration 9: updated label model\n",
      "BC iteration 9: updated worker models\n",
      "Computing LB=-739040.5841: label model and labels=-56847.3047, annotator model=-284.1250, features=-681909.1388\n",
      "BSC: max. difference at iteration 10: 23.21829\n",
      "BC iteration 10 in progress\n",
      "BAC iteration 10: completed forward pass\n",
      "BAC iteration 10: completed backward pass\n",
      "BC iteration 10: computed label probabilities\n",
      "BC iteration 10: updated label model\n",
      "BC iteration 10: updated worker models\n",
      "Computing LB=-739018.5257: label model and labels=-56834.1797, annotator model=-283.0000, features=-681901.3850\n",
      "BSC: max. difference at iteration 11: 22.05846\n",
      "BC iteration 11 in progress\n",
      "BAC iteration 11: completed forward pass\n",
      "BAC iteration 11: completed backward pass\n",
      "BC iteration 11: computed label probabilities\n",
      "BC iteration 11: updated label model\n",
      "BC iteration 11: updated worker models\n",
      "Computing LB=-738986.4529: label model and labels=-56808.4219, annotator model=-282.1250, features=-681895.8983\n",
      "BSC: max. difference at iteration 12: 32.07273\n",
      "BC iteration 12 in progress\n",
      "BAC iteration 12: completed forward pass\n",
      "BAC iteration 12: completed backward pass\n",
      "BC iteration 12: computed label probabilities\n",
      "BC iteration 12: updated label model\n",
      "BC iteration 12: updated worker models\n",
      "Computing LB=-738970.8113: label model and labels=-56796.9141, annotator model=-281.3750, features=-681892.5144\n",
      "BSC: max. difference at iteration 13: 15.64162\n",
      "BC iteration 13 in progress\n",
      "BAC iteration 13: completed forward pass\n",
      "BAC iteration 13: completed backward pass\n",
      "BC iteration 13: computed label probabilities\n",
      "BC iteration 13: updated label model\n",
      "BC iteration 13: updated worker models\n",
      "Computing LB=-738961.1540: label model and labels=-56791.2656, annotator model=-281.1250, features=-681888.7165\n",
      "BSC: max. difference at iteration 14: 9.65727\n",
      "BC iteration 14 in progress\n",
      "BAC iteration 14: completed forward pass\n",
      "BAC iteration 14: completed backward pass\n",
      "BC iteration 14: computed label probabilities\n",
      "BC iteration 14: updated label model\n",
      "BC iteration 14: updated worker models\n",
      "Computing LB=-738946.9792: label model and labels=-56780.7344, annotator model=-280.5000, features=-681885.7214\n",
      "BSC: max. difference at iteration 15: 14.17483\n",
      "BC iteration 15 in progress\n",
      "BAC iteration 15: completed forward pass\n",
      "BAC iteration 15: completed backward pass\n",
      "BC iteration 15: computed label probabilities\n",
      "BC iteration 15: updated label model\n",
      "BC iteration 15: updated worker models\n",
      "Computing LB=-738937.7467: label model and labels=-56773.5781, annotator model=-280.2500, features=-681883.8951\n",
      "BSC: max. difference at iteration 16: 9.23254\n",
      "BC iteration 16 in progress\n",
      "BAC iteration 16: completed forward pass\n",
      "BAC iteration 16: completed backward pass\n",
      "BC iteration 16: computed label probabilities\n",
      "BC iteration 16: updated label model\n",
      "BC iteration 16: updated worker models\n",
      "Computing LB=-738931.9689: label model and labels=-56769.3281, annotator model=-280.2500, features=-681882.3517\n",
      "BSC: max. difference at iteration 17: 5.77776\n",
      "BC iteration 17 in progress\n",
      "BAC iteration 17: completed forward pass\n",
      "BAC iteration 17: completed backward pass\n",
      "BC iteration 17: computed label probabilities\n",
      "BC iteration 17: updated label model\n",
      "BC iteration 17: updated worker models\n",
      "Computing LB=-738927.0514: label model and labels=-56765.8594, annotator model=-280.2500, features=-681880.9967\n",
      "BSC: max. difference at iteration 18: 4.91754\n",
      "BC iteration 18 in progress\n",
      "BAC iteration 18: completed forward pass\n",
      "BAC iteration 18: completed backward pass\n",
      "BC iteration 18: computed label probabilities\n",
      "BC iteration 18: updated label model\n",
      "BC iteration 18: updated worker models\n",
      "Computing LB=-738922.9023: label model and labels=-56762.8594, annotator model=-280.2500, features=-681879.7617\n",
      "BSC: max. difference at iteration 19: 4.14903\n",
      "BC iteration 19 in progress\n",
      "BAC iteration 19: completed forward pass\n",
      "BAC iteration 19: completed backward pass\n",
      "BC iteration 19: computed label probabilities\n",
      "BC iteration 19: updated label model\n",
      "BC iteration 19: updated worker models\n",
      "Computing LB=-738919.0365: label model and labels=-56760.1875, annotator model=-280.2500, features=-681878.5521\n",
      "BSC: max. difference at iteration 20: 3.86589\n",
      "BC iteration 20 in progress\n",
      "BAC iteration 20: completed forward pass\n",
      "BAC iteration 20: completed backward pass\n",
      "BC iteration 20: computed label probabilities\n",
      "BC iteration 20: updated label model\n",
      "BC iteration 20: updated worker models\n",
      "Computing LB=-738915.0719: label model and labels=-56757.5938, annotator model=-280.2500, features=-681877.2360\n",
      "BSC: max. difference at iteration 21: 3.96450\n",
      "BC iteration 21 in progress\n",
      "BAC iteration 21: completed forward pass\n",
      "BAC iteration 21: completed backward pass\n",
      "BC iteration 21: computed label probabilities\n",
      "BC iteration 21: updated label model\n",
      "BC iteration 21: updated worker models\n",
      "Computing LB=-738911.0441: label model and labels=-56754.9297, annotator model=-280.5000, features=-681875.6456\n",
      "BSC: max. difference at iteration 22: 4.02786\n",
      "BC iteration 22 in progress\n",
      "BAC iteration 22: completed forward pass\n",
      "BAC iteration 22: completed backward pass\n",
      "BC iteration 22: computed label probabilities\n",
      "BC iteration 22: updated label model\n",
      "BC iteration 22: updated worker models\n",
      "Computing LB=-738906.0025: label model and labels=-56752.0625, annotator model=-280.2500, features=-681873.6822\n",
      "BSC: max. difference at iteration 23: 5.04156\n",
      "BC iteration 23 in progress\n",
      "BAC iteration 23: completed forward pass\n",
      "BAC iteration 23: completed backward pass\n",
      "BC iteration 23: computed label probabilities\n",
      "BC iteration 23: updated label model\n",
      "BC iteration 23: updated worker models\n",
      "Computing LB=-738900.6760: label model and labels=-56748.7969, annotator model=-280.3750, features=-681871.5432\n",
      "BSC: max. difference at iteration 24: 5.32654\n",
      "BC iteration 24 in progress\n",
      "BAC iteration 24: completed forward pass\n",
      "BAC iteration 24: completed backward pass\n",
      "BC iteration 24: computed label probabilities\n",
      "BC iteration 24: updated label model\n",
      "BC iteration 24: updated worker models\n",
      "Computing LB=-738895.4236: label model and labels=-56745.3516, annotator model=-280.2500, features=-681869.8767\n",
      "BSC: max. difference at iteration 25: 5.25241\n",
      "BC iteration 25 in progress\n",
      "BAC iteration 25: completed forward pass\n",
      "BAC iteration 25: completed backward pass\n",
      "BC iteration 25: computed label probabilities\n",
      "BC iteration 25: updated label model\n",
      "BC iteration 25: updated worker models\n",
      "Computing LB=-738892.1734: label model and labels=-56743.0234, annotator model=-280.2500, features=-681868.9077\n",
      "BSC: max. difference at iteration 26: 3.25021\n",
      "BC iteration 26 in progress\n",
      "BAC iteration 26: completed forward pass\n",
      "BAC iteration 26: completed backward pass\n",
      "BC iteration 26: computed label probabilities\n",
      "BC iteration 26: updated label model\n",
      "BC iteration 26: updated worker models\n",
      "Computing LB=-738890.4397: label model and labels=-56741.9297, annotator model=-280.2500, features=-681868.2678\n",
      "BSC: max. difference at iteration 27: 1.73367\n",
      "BC iteration 27 in progress\n",
      "BAC iteration 27: completed forward pass\n",
      "BAC iteration 27: completed backward pass\n",
      "BC iteration 27: computed label probabilities\n",
      "BC iteration 27: updated label model\n",
      "BC iteration 27: updated worker models\n",
      "Computing LB=-738889.3229: label model and labels=-56741.4219, annotator model=-280.1250, features=-681867.7526\n",
      "BSC: max. difference at iteration 28: 1.11678\n",
      "BC iteration 28 in progress\n",
      "BAC iteration 28: completed forward pass\n",
      "BAC iteration 28: completed backward pass\n",
      "BC iteration 28: computed label probabilities\n",
      "BC iteration 28: updated label model\n",
      "BC iteration 28: updated worker models\n",
      "Computing LB=-738888.7666: label model and labels=-56741.1719, annotator model=-280.2500, features=-681867.2978\n",
      "BSC: max. difference at iteration 29: 0.55636\n",
      "BC iteration 29 in progress\n",
      "BAC iteration 29: completed forward pass\n",
      "BAC iteration 29: completed backward pass\n",
      "BC iteration 29: computed label probabilities\n",
      "BC iteration 29: updated label model\n",
      "BC iteration 29: updated worker models\n",
      "BC iteration 30: computing most likely labels...\n",
      "BAC iteration 30: completed forward pass\n",
      "BAC iteration 30: completed backward pass\n",
      "BC iteration 30: fitting/predicting complete.\n"
     ]
    }
   ],
   "source": [
    "# fit model\n",
    "amodel.fit_predict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "821b1bc1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "214.11958279204555\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[inf,\n",
       " 18238.241818634793,\n",
       " 1897.0624387806747,\n",
       " 318.91457307501696,\n",
       " 160.73717962857336,\n",
       " 124.69076562579721,\n",
       " 93.46654740266968,\n",
       " 50.287831012159586,\n",
       " 23.218291213270277,\n",
       " 22.058464309899136,\n",
       " 32.072725362260826,\n",
       " 15.641622876748443,\n",
       " 9.657274924917147,\n",
       " 14.174832683289424,\n",
       " 9.232539518037811,\n",
       " 5.777763915481046,\n",
       " 4.9175397456856444,\n",
       " 4.149025749065913,\n",
       " 3.865889688488096,\n",
       " 3.96450487524271,\n",
       " 4.027864253614098,\n",
       " 5.041564604500309,\n",
       " 5.326535994186997,\n",
       " 5.2524078383576125,\n",
       " 3.2502073486102745,\n",
       " 1.7336729498347268,\n",
       " 1.116782711003907,\n",
       " 0.5563615111168474]"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(amodel.runtime_)\n",
    "amodel.model.convergence_history"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "284cfabd",
   "metadata": {},
   "source": [
    "### add sentence metadata to posterior label estimates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "e788e816",
   "metadata": {},
   "outputs": [],
   "source": [
    "# add metadata\n",
    "metadata.drop(['split_', 'job'], axis=1, inplace=True)\n",
    "metadata.set_index('uid', drop=True, verify_integrity=True, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "ef99669e",
   "metadata": {},
   "outputs": [],
   "source": [
    "for doc in amodel.corpus.docs:\n",
    "    doc.metadata = {k: v.item() if isinstance(v, np.int64) else v for k, v in dict(metadata.loc[doc.id]).items() if not pd.isna(v)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "3234da91",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([ True]), array([1573]))"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.unique(np.array([hasattr(doc, 'metadata') for doc in amodel.corpus.docs]), return_counts=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd9cf9fe",
   "metadata": {},
   "source": [
    "## Write to disk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "dbcd0e12",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _jsonify(self, fields=['id', 'text', 'tokens', 'annotations', 'labels', 'metadata']):\n",
    "    out = {k: self.__dict__[k] for k in fields if hasattr(self, k)}\n",
    "    if 'annotations' in out.keys():\n",
    "        out['annotations'] = {k: v.tolist() for k, v in out['annotations'].items()}\n",
    "    if 'labels' in out.keys():\n",
    "        out['labels'] = {k: v.tolist() for k, v in out['labels'].items()}\n",
    "        if len(out['labels']) == 0:\n",
    "            out['labels'] = None\n",
    "    if 'metadata' in out.keys():\n",
    "        if len(out['metadata']) == 0:\n",
    "            out['metadata'] = None\n",
    "    return json.dumps(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f270ad74",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.exists(args.output_file) or args.overwrite_output:\n",
    "    with open(args.output_file, mode='w', encoding='utf-8') as writer:\n",
    "        for doc in amodel.corpus.docs:\n",
    "            writer.write(_jsonify(doc)+'\\n')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "group_mention_detection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
